{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "markdown"
    }
   },
   "source": [
    "# Corrective RAG (CRAG) Implementation\n",
    "\n",
    "In this notebook, I implement Corrective RAG - an advanced approach that dynamically evaluates retrieved information and corrects the retrieval process when necessary, using web search as a fallback.\n",
    "\n",
    "CRAG improves on traditional RAG by:\n",
    "\n",
    "- Evaluating retrieved content before using it\n",
    "- Dynamically switching between knowledge sources based on relevance\n",
    "- Correcting the retrieval with web search when local knowledge is insufficient\n",
    "- Combining information from multiple sources when appropriate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting Up the Environment\n",
    "We begin by importing necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "import fitz  # PyMuPDF\n",
    "from openai import OpenAI\n",
    "import requests\n",
    "from typing import List, Dict, Tuple, Any\n",
    "import re\n",
    "from urllib.parse import quote_plus\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting Up the OpenAI API Client\n",
    "We initialize the OpenAI client to generate embeddings and responses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the OpenAI client with the base URL and API key\n",
    "client = OpenAI(\n",
    "    base_url=\"https://api.studio.nebius.com/v1/\",\n",
    "    api_key=os.getenv(\"OPENAI_API_KEY\")  # Retrieve the API key from environment variables\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Document Processing Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_text_from_pdf(pdf_path):\n",
    "    \"\"\"\n",
    "    Extract text content from a PDF file.\n",
    "    \n",
    "    Args:\n",
    "        pdf_path (str): Path to the PDF file\n",
    "        \n",
    "    Returns:\n",
    "        str: Extracted text content\n",
    "    \"\"\"\n",
    "    print(f\"Extracting text from {pdf_path}...\")\n",
    "    \n",
    "    # Open the PDF file\n",
    "    pdf = fitz.open(pdf_path)\n",
    "    text = \"\"\n",
    "    \n",
    "    # Iterate through each page in the PDF\n",
    "    for page_num in range(len(pdf)):\n",
    "        page = pdf[page_num]\n",
    "        # Extract text from the current page and append it to the text variable\n",
    "        text += page.get_text()\n",
    "    \n",
    "    return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunk_text(text, chunk_size=1000, overlap=200):\n",
    "    \"\"\"\n",
    "    Split text into overlapping chunks for efficient retrieval and processing.\n",
    "    \n",
    "    This function divides a large text into smaller, manageable chunks with\n",
    "    specified overlap between consecutive chunks. Chunking is critical for RAG\n",
    "    systems as it allows for more precise retrieval of relevant information.\n",
    "    \n",
    "    Args:\n",
    "        text (str): Input text to be chunked\n",
    "        chunk_size (int): Maximum size of each chunk in characters\n",
    "        overlap (int): Number of overlapping characters between consecutive chunks\n",
    "                       to maintain context across chunk boundaries\n",
    "        \n",
    "    Returns:\n",
    "        List[Dict]: List of text chunks, each containing:\n",
    "                   - text: The chunk content\n",
    "                   - metadata: Dictionary with positional information and source type\n",
    "    \"\"\"\n",
    "    chunks = []\n",
    "    \n",
    "    # Iterate through the text with a sliding window approach\n",
    "    # Moving by (chunk_size - overlap) ensures proper overlap between chunks\n",
    "    for i in range(0, len(text), chunk_size - overlap):\n",
    "        # Extract the current chunk, limited by chunk_size\n",
    "        chunk_text = text[i:i + chunk_size]\n",
    "        \n",
    "        # Only add non-empty chunks\n",
    "        if chunk_text:\n",
    "            chunks.append({\n",
    "                \"text\": chunk_text,  # The actual text content\n",
    "                \"metadata\": {\n",
    "                    \"start_pos\": i,  # Starting position in the original text\n",
    "                    \"end_pos\": i + len(chunk_text),  # Ending position\n",
    "                    \"source_type\": \"document\"  # Indicates the source of this text\n",
    "                }\n",
    "            })\n",
    "    \n",
    "    print(f\"Created {len(chunks)} text chunks\")\n",
    "    return chunks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple Vector Store Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleVectorStore:\n",
    "    \"\"\"\n",
    "    A simple vector store implementation using NumPy.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        # Initialize lists to store vectors, texts, and metadata\n",
    "        self.vectors = []\n",
    "        self.texts = []\n",
    "        self.metadata = []\n",
    "    \n",
    "    def add_item(self, text, embedding, metadata=None):\n",
    "        \"\"\"\n",
    "        Add an item to the vector store.\n",
    "        \n",
    "        Args:\n",
    "            text (str): The text content\n",
    "            embedding (List[float]): The embedding vector\n",
    "            metadata (Dict, optional): Additional metadata\n",
    "        \"\"\"\n",
    "        # Append the embedding, text, and metadata to their respective lists\n",
    "        self.vectors.append(np.array(embedding))\n",
    "        self.texts.append(text)\n",
    "        self.metadata.append(metadata or {})\n",
    "    \n",
    "    def add_items(self, items, embeddings):\n",
    "        \"\"\"\n",
    "        Add multiple items to the vector store.\n",
    "        \n",
    "        Args:\n",
    "            items (List[Dict]): List of items with text and metadata\n",
    "            embeddings (List[List[float]]): List of embedding vectors\n",
    "        \"\"\"\n",
    "        # Iterate over items and embeddings and add them to the store\n",
    "        for i, (item, embedding) in enumerate(zip(items, embeddings)):\n",
    "            self.add_item(\n",
    "                text=item[\"text\"],\n",
    "                embedding=embedding,\n",
    "                metadata=item.get(\"metadata\", {})\n",
    "            )\n",
    "    \n",
    "    def similarity_search(self, query_embedding, k=5):\n",
    "        \"\"\"\n",
    "        Find the most similar items to a query embedding.\n",
    "        \n",
    "        Args:\n",
    "            query_embedding (List[float]): Query embedding vector\n",
    "            k (int): Number of results to return\n",
    "            \n",
    "        Returns:\n",
    "            List[Dict]: Top k most similar items\n",
    "        \"\"\"\n",
    "        # Return an empty list if there are no vectors in the store\n",
    "        if not self.vectors:\n",
    "            return []\n",
    "        \n",
    "        # Convert query embedding to numpy array\n",
    "        query_vector = np.array(query_embedding)\n",
    "        \n",
    "        # Calculate similarities using cosine similarity\n",
    "        similarities = []\n",
    "        for i, vector in enumerate(self.vectors):\n",
    "            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))\n",
    "            similarities.append((i, similarity))\n",
    "        \n",
    "        # Sort by similarity (descending)\n",
    "        similarities.sort(key=lambda x: x[1], reverse=True)\n",
    "        \n",
    "        # Return top k results\n",
    "        results = []\n",
    "        for i in range(min(k, len(similarities))):\n",
    "            idx, score = similarities[i]\n",
    "            results.append({\n",
    "                \"text\": self.texts[idx],\n",
    "                \"metadata\": self.metadata[idx],\n",
    "                \"similarity\": float(score)\n",
    "            })\n",
    "        \n",
    "        return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_embeddings(texts, model=\"text-embedding-3-small\"):\n",
    "    \"\"\"\n",
    "    Create vector embeddings for text inputs using OpenAI's embedding models.\n",
    "    \n",
    "    Embeddings are dense vector representations of text that capture semantic meaning,\n",
    "    allowing for similarity comparisons. In RAG systems, embeddings are essential\n",
    "    for matching queries with relevant document chunks.\n",
    "    \n",
    "    Args:\n",
    "        texts (str or List[str]): Input text(s) to be embedded. Can be a single string\n",
    "                                  or a list of strings.\n",
    "        model (str): The embedding model name to use. Defaults to \"text-embedding-3-small\".\n",
    "        \n",
    "    Returns:\n",
    "        List[List[float]]: If input is a list, returns a list of embedding vectors.\n",
    "                          If input is a single string, returns a single embedding vector.\n",
    "    \"\"\"\n",
    "    # Handle both single string and list inputs by converting single strings to a list\n",
    "    input_texts = texts if isinstance(texts, list) else [texts]\n",
    "    \n",
    "    # Process in batches to avoid API rate limits and payload size restrictions\n",
    "    # OpenAI API typically has limits on request size and rate\n",
    "    batch_size = 100\n",
    "    all_embeddings = []\n",
    "    \n",
    "    # Process each batch of texts\n",
    "    for i in range(0, len(input_texts), batch_size):\n",
    "        # Extract the current batch of texts\n",
    "        batch = input_texts[i:i + batch_size]\n",
    "        \n",
    "        # Make API call to generate embeddings for the current batch\n",
    "        response = client.embeddings.create(\n",
    "            model=model,\n",
    "            input=batch\n",
    "        )\n",
    "        \n",
    "        # Extract the embedding vectors from the response\n",
    "        batch_embeddings = [item.embedding for item in response.data]\n",
    "        all_embeddings.extend(batch_embeddings)\n",
    "    \n",
    "    # If the original input was a single string, return just the first embedding\n",
    "    if isinstance(texts, str):\n",
    "        return all_embeddings[0]\n",
    "    \n",
    "    # Otherwise return the full list of embeddings\n",
    "    return all_embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Document Processing Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_document(pdf_path, chunk_size=1000, chunk_overlap=200):\n",
    "    \"\"\"\n",
    "    Process a document into a vector store.\n",
    "    \n",
    "    Args:\n",
    "        pdf_path (str): Path to the PDF file\n",
    "        chunk_size (int): Size of each chunk in characters\n",
    "        chunk_overlap (int): Overlap between chunks in characters\n",
    "        \n",
    "    Returns:\n",
    "        SimpleVectorStore: Vector store containing document chunks\n",
    "    \"\"\"\n",
    "    # Extract text from the PDF file\n",
    "    text = extract_text_from_pdf(pdf_path)\n",
    "    \n",
    "    # Split the extracted text into chunks with specified size and overlap\n",
    "    chunks = chunk_text(text, chunk_size, chunk_overlap)\n",
    "    \n",
    "    # Create embeddings for each chunk of text\n",
    "    print(\"Creating embeddings for chunks...\")\n",
    "    chunk_texts = [chunk[\"text\"] for chunk in chunks]\n",
    "    chunk_embeddings = create_embeddings(chunk_texts)\n",
    "    \n",
    "    # Initialize a new vector store\n",
    "    vector_store = SimpleVectorStore()\n",
    "    \n",
    "    # Add the chunks and their embeddings to the vector store\n",
    "    vector_store.add_items(chunks, chunk_embeddings)\n",
    "    \n",
    "    print(f\"Vector store created with {len(chunks)} chunks\")\n",
    "    return vector_store"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Relevance Evaluation Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_document_relevance(query, document):\n",
    "    \"\"\"\n",
    "    Evaluate the relevance of a document to a query.\n",
    "    \n",
    "    Args:\n",
    "        query (str): User query\n",
    "        document (str): Document text\n",
    "        \n",
    "    Returns:\n",
    "        float: Relevance score (0-1)\n",
    "    \"\"\"\n",
    "    # Define the system prompt to instruct the model on how to evaluate relevance\n",
    "    system_prompt = \"\"\"\n",
    "    You are an expert at evaluating document relevance. \n",
    "    Rate how relevant the given document is to the query on a scale from 0 to 1.\n",
    "    0 means completely irrelevant, 1 means perfectly relevant.\n",
    "    Provide ONLY the score as a float between 0 and 1.\n",
    "    \"\"\"\n",
    "    \n",
    "    # Define the user prompt with the query and document\n",
    "    user_prompt = f\"Query: {query}\\n\\nDocument: {document}\"\n",
    "    \n",
    "    try:\n",
    "        # Make a request to the OpenAI API to evaluate the relevance\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"gpt-3.5-turbo\",  # Specify the model to use\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},  # System message to guide the assistant\n",
    "                {\"role\": \"user\", \"content\": user_prompt}  # User message with the query and document\n",
    "            ],\n",
    "            temperature=0,  # Set the temperature for response generation\n",
    "            max_tokens=5  # Very short response needed\n",
    "        )\n",
    "        \n",
    "        # Extract the score from the response\n",
    "        score_text = response.choices[0].message.content.strip()\n",
    "        # Use regex to find the float value in the response\n",
    "        score_match = re.search(r'(\\d+(\\.\\d+)?)', score_text)\n",
    "        if score_match:\n",
    "            return float(score_match.group(1))  # Return the extracted score as a float\n",
    "        return 0.5  # Default to middle value if parsing fails\n",
    "    \n",
    "    except Exception as e:\n",
    "        # Print the error message and return a default value on error\n",
    "        print(f\"Error evaluating document relevance: {e}\")\n",
    "        return 0.5  # Default to middle value on error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Web Search Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def duck_duck_go_search(query, num_results=3):\n",
    "    \"\"\"\n",
    "    Perform a web search using DuckDuckGo.\n",
    "    \n",
    "    Args:\n",
    "        query (str): Search query\n",
    "        num_results (int): Number of results to return\n",
    "        \n",
    "    Returns:\n",
    "        Tuple[str, List[Dict]]: Combined search results text and source metadata\n",
    "    \"\"\"\n",
    "    # Encode the query for URL\n",
    "    encoded_query = quote_plus(query)\n",
    "    \n",
    "    # DuckDuckGo search API endpoint (unofficial)\n",
    "    url = f\"https://api.duckduckgo.com/?q={encoded_query}&format=json\"\n",
    "    \n",
    "    try:\n",
    "        # Perform the web search request\n",
    "        response = requests.get(url, headers={\n",
    "            \"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36\"\n",
    "        })\n",
    "        data = response.json()\n",
    "        \n",
    "        # Initialize variables to store results text and sources\n",
    "        results_text = \"\"\n",
    "        sources = []\n",
    "        \n",
    "        # Add abstract if available\n",
    "        if data.get(\"AbstractText\"):\n",
    "            results_text += f\"{data['AbstractText']}\\n\\n\"\n",
    "            sources.append({\n",
    "                \"title\": data.get(\"AbstractSource\", \"Wikipedia\"),\n",
    "                \"url\": data.get(\"AbstractURL\", \"\")\n",
    "            })\n",
    "        \n",
    "        # Add related topics\n",
    "        for topic in data.get(\"RelatedTopics\", [])[:num_results]:\n",
    "            if \"Text\" in topic and \"FirstURL\" in topic:\n",
    "                results_text += f\"{topic['Text']}\\n\\n\"\n",
    "                sources.append({\n",
    "                    \"title\": topic.get(\"Text\", \"\").split(\" - \")[0],\n",
    "                    \"url\": topic.get(\"FirstURL\", \"\")\n",
    "                })\n",
    "        \n",
    "        return results_text, sources\n",
    "    \n",
    "    except Exception as e:\n",
    "        # Print error message if the main search fails\n",
    "        print(f\"Error performing web search: {e}\")\n",
    "        \n",
    "        # Fallback to a backup search API\n",
    "        try:\n",
    "            backup_url = f\"https://serpapi.com/search.json?q={encoded_query}&engine=duckduckgo\"\n",
    "            response = requests.get(backup_url)\n",
    "            data = response.json()\n",
    "            \n",
    "            # Initialize variables to store results text and sources\n",
    "            results_text = \"\"\n",
    "            sources = []\n",
    "            \n",
    "            # Extract results from the backup API\n",
    "            for result in data.get(\"organic_results\", [])[:num_results]:\n",
    "                results_text += f\"{result.get('title', '')}: {result.get('snippet', '')}\\n\\n\"\n",
    "                sources.append({\n",
    "                    \"title\": result.get(\"title\", \"\"),\n",
    "                    \"url\": result.get(\"link\", \"\")\n",
    "                })\n",
    "            \n",
    "            return results_text, sources\n",
    "        except Exception as backup_error:\n",
    "            # Print error message if the backup search also fails\n",
    "            print(f\"Backup search also failed: {backup_error}\")\n",
    "            return \"Failed to retrieve search results.\", []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rewrite_search_query(query):\n",
    "    \"\"\"\n",
    "    Rewrite a query to be more suitable for web search.\n",
    "    \n",
    "    Args:\n",
    "        query (str): Original query\n",
    "        \n",
    "    Returns:\n",
    "        str: Rewritten query\n",
    "    \"\"\"\n",
    "    # Define the system prompt to instruct the model on how to rewrite the query\n",
    "    system_prompt = \"\"\"\n",
    "    You are an expert at creating effective search queries.\n",
    "    Rewrite the given query to make it more suitable for a web search engine.\n",
    "    Focus on keywords and facts, remove unnecessary words, and make it concise.\n",
    "    \"\"\"\n",
    "    \n",
    "    try:\n",
    "        # Make a request to the OpenAI API to rewrite the query\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"gpt-3.5-turbo\",  # Specify the model to use\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},  # System message to guide the assistant\n",
    "                {\"role\": \"user\", \"content\": f\"Original query: {query}\\n\\nRewritten query:\"}  # User message with the original query\n",
    "            ],\n",
    "            temperature=0.3,  # Set the temperature for response generation\n",
    "            max_tokens=50  # Limit the response length\n",
    "        )\n",
    "        \n",
    "        # Return the rewritten query from the response\n",
    "        return response.choices[0].message.content.strip()\n",
    "    except Exception as e:\n",
    "        # Print the error message and return the original query on error\n",
    "        print(f\"Error rewriting search query: {e}\")\n",
    "        return query  # Return original query on error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_web_search(query):\n",
    "    \"\"\"\n",
    "    Perform web search with query rewriting.\n",
    "    \n",
    "    Args:\n",
    "        query (str): Original user query\n",
    "        \n",
    "    Returns:\n",
    "        Tuple[str, List[Dict]]: Search results text and source metadata\n",
    "    \"\"\"\n",
    "    # Rewrite the query to improve search results\n",
    "    rewritten_query = rewrite_search_query(query)\n",
    "    print(f\"Rewritten search query: {rewritten_query}\")\n",
    "    \n",
    "    # Perform the web search using the rewritten query\n",
    "    results_text, sources = duck_duck_go_search(rewritten_query)\n",
    "    \n",
    "    # Return the search results text and source metadata\n",
    "    return results_text, sources"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Knowledge Refinement Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def refine_knowledge(text):\n",
    "    \"\"\"\n",
    "    Extract and refine key information from text.\n",
    "    \n",
    "    Args:\n",
    "        text (str): Input text to refine\n",
    "        \n",
    "    Returns:\n",
    "        str: Refined key points from the text\n",
    "    \"\"\"\n",
    "    # Define the system prompt to instruct the model on how to extract key information\n",
    "    system_prompt = \"\"\"\n",
    "    Extract the key information from the following text as a set of clear, concise bullet points.\n",
    "    Focus on the most relevant facts and important details.\n",
    "    Format your response as a bulleted list with each point on a new line starting with \"• \".\n",
    "    \"\"\"\n",
    "    \n",
    "    try:\n",
    "        # Make a request to the OpenAI API to refine the text\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"gpt-3.5-turbo\",  # Specify the model to use\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},  # System message to guide the assistant\n",
    "                {\"role\": \"user\", \"content\": f\"Text to refine:\\n\\n{text}\"}  # User message with the text to refine\n",
    "            ],\n",
    "            temperature=0.3  # Set the temperature for response generation\n",
    "        )\n",
    "        \n",
    "        # Return the refined key points from the response\n",
    "        return response.choices[0].message.content.strip()\n",
    "    except Exception as e:\n",
    "        # Print the error message and return the original text on error\n",
    "        print(f\"Error refining knowledge: {e}\")\n",
    "        return text  # Return original text on error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Core CRAG Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def crag_process(query, vector_store, k=3):\n",
    "    \"\"\"\n",
    "    Run the Corrective RAG process.\n",
    "    \n",
    "    Args:\n",
    "        query (str): User query\n",
    "        vector_store (SimpleVectorStore): Vector store with document chunks\n",
    "        k (int): Number of initial documents to retrieve\n",
    "        \n",
    "    Returns:\n",
    "        Dict: Process results including response and debug info\n",
    "    \"\"\"\n",
    "    print(f\"\\n=== Processing query with CRAG: {query} ===\\n\")\n",
    "    \n",
    "    # Step 1: Create query embedding and retrieve documents\n",
    "    print(\"Retrieving initial documents...\")\n",
    "    query_embedding = create_embeddings(query)\n",
    "    retrieved_docs = vector_store.similarity_search(query_embedding, k=k)\n",
    "    \n",
    "    # Step 2: Evaluate document relevance\n",
    "    print(\"Evaluating document relevance...\")\n",
    "    relevance_scores = []\n",
    "    for doc in retrieved_docs:\n",
    "        score = evaluate_document_relevance(query, doc[\"text\"])\n",
    "        relevance_scores.append(score)\n",
    "        doc[\"relevance\"] = score\n",
    "        print(f\"Document scored {score:.2f} relevance\")\n",
    "    \n",
    "    # Step 3: Determine action based on best relevance score\n",
    "    max_score = max(relevance_scores) if relevance_scores else 0\n",
    "    best_doc_idx = relevance_scores.index(max_score) if relevance_scores else -1\n",
    "    \n",
    "    # Track sources for attribution\n",
    "    sources = []\n",
    "    final_knowledge = \"\"\n",
    "    \n",
    "    # Step 4: Execute the appropriate knowledge acquisition strategy\n",
    "    if max_score > 0.7:\n",
    "        # Case 1: High relevance - Use document directly\n",
    "        print(f\"High relevance ({max_score:.2f}) - Using document directly\")\n",
    "        best_doc = retrieved_docs[best_doc_idx][\"text\"]\n",
    "        final_knowledge = best_doc\n",
    "        sources.append({\n",
    "            \"title\": \"Document\",\n",
    "            \"url\": \"\"\n",
    "        })\n",
    "        \n",
    "    elif max_score < 0.3:\n",
    "        # Case 2: Low relevance - Use web search\n",
    "        print(f\"Low relevance ({max_score:.2f}) - Performing web search\")\n",
    "        web_results, web_sources = perform_web_search(query)\n",
    "        final_knowledge = refine_knowledge(web_results)\n",
    "        sources.extend(web_sources)\n",
    "        \n",
    "    else:\n",
    "        # Case 3: Medium relevance - Combine document with web search\n",
    "        print(f\"Medium relevance ({max_score:.2f}) - Combining document with web search\")\n",
    "        best_doc = retrieved_docs[best_doc_idx][\"text\"]\n",
    "        refined_doc = refine_knowledge(best_doc)\n",
    "        \n",
    "        # Get web results\n",
    "        web_results, web_sources = perform_web_search(query)\n",
    "        refined_web = refine_knowledge(web_results)\n",
    "        \n",
    "        # Combine knowledge\n",
    "        final_knowledge = f\"From document:\\n{refined_doc}\\n\\nFrom web search:\\n{refined_web}\"\n",
    "        \n",
    "        # Add sources\n",
    "        sources.append({\n",
    "            \"title\": \"Document\",\n",
    "            \"url\": \"\"\n",
    "        })\n",
    "        sources.extend(web_sources)\n",
    "    \n",
    "    # Step 5: Generate final response\n",
    "    print(\"Generating final response...\")\n",
    "    response = generate_response(query, final_knowledge, sources)\n",
    "    \n",
    "    # Return comprehensive results\n",
    "    return {\n",
    "        \"query\": query,\n",
    "        \"response\": response,\n",
    "        \"retrieved_docs\": retrieved_docs,\n",
    "        \"relevance_scores\": relevance_scores,\n",
    "        \"max_relevance\": max_score,\n",
    "        \"final_knowledge\": final_knowledge,\n",
    "        \"sources\": sources\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Response Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_response(query, knowledge, sources):\n",
    "    \"\"\"\n",
    "    Generate a response based on the query and knowledge.\n",
    "    \n",
    "    Args:\n",
    "        query (str): User query\n",
    "        knowledge (str): Knowledge to base the response on\n",
    "        sources (List[Dict]): List of sources with title and URL\n",
    "        \n",
    "    Returns:\n",
    "        str: Generated response\n",
    "    \"\"\"\n",
    "    # Format sources for inclusion in prompt\n",
    "    sources_text = \"\"\n",
    "    for source in sources:\n",
    "        title = source.get(\"title\", \"Unknown Source\")\n",
    "        url = source.get(\"url\", \"\")\n",
    "        if url:\n",
    "            sources_text += f\"- {title}: {url}\\n\"\n",
    "        else:\n",
    "            sources_text += f\"- {title}\\n\"\n",
    "    \n",
    "    # Define the system prompt to instruct the model on how to generate the response\n",
    "    system_prompt = \"\"\"\n",
    "    You are a helpful AI assistant. Generate a comprehensive, informative response to the query based on the provided knowledge.\n",
    "    Include all relevant information while keeping your answer clear and concise.\n",
    "    If the knowledge doesn't fully answer the query, acknowledge this limitation.\n",
    "    Include source attribution at the end of your response.\n",
    "    \"\"\"\n",
    "    \n",
    "    # Define the user prompt with the query, knowledge, and sources\n",
    "    user_prompt = f\"\"\"\n",
    "    Query: {query}\n",
    "    \n",
    "    Knowledge:\n",
    "    {knowledge}\n",
    "    \n",
    "    Sources:\n",
    "    {sources_text}\n",
    "    \n",
    "    Please provide an informative response to the query based on this information.\n",
    "    Include the sources at the end of your response.\n",
    "    \"\"\"\n",
    "    \n",
    "    try:\n",
    "        # Make a request to the OpenAI API to generate the response\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"gpt-4\",  # Using GPT-4 for high-quality responses\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_prompt}\n",
    "            ],\n",
    "            temperature=0.2\n",
    "        )\n",
    "        \n",
    "        # Return the generated response\n",
    "        return response.choices[0].message.content.strip()\n",
    "    except Exception as e:\n",
    "        # Print the error message and return an error response\n",
    "        print(f\"Error generating response: {e}\")\n",
    "        return f\"I apologize, but I encountered an error while generating a response to your query: '{query}'. The error was: {str(e)}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_crag_response(query, response, reference_answer=None):\n",
    "    \"\"\"\n",
    "    Evaluate the quality of a CRAG response.\n",
    "    \n",
    "    Args:\n",
    "        query (str): User query\n",
    "        response (str): Generated response\n",
    "        reference_answer (str, optional): Reference answer for comparison\n",
    "        \n",
    "    Returns:\n",
    "        Dict: Evaluation metrics\n",
    "    \"\"\"\n",
    "    # System prompt for the evaluation criteria\n",
    "    system_prompt = \"\"\"\n",
    "    You are an expert at evaluating the quality of responses to questions.\n",
    "    Please evaluate the provided response based on the following criteria:\n",
    "    \n",
    "    1. Relevance (0-10): How directly does the response address the query?\n",
    "    2. Accuracy (0-10): How factually correct is the information?\n",
    "    3. Completeness (0-10): How thoroughly does the response answer all aspects of the query?\n",
    "    4. Clarity (0-10): How clear and easy to understand is the response?\n",
    "    5. Source Quality (0-10): How well does the response cite relevant sources?\n",
    "    \n",
    "    Return your evaluation as a JSON object with scores for each criterion and a brief explanation for each score.\n",
    "    Also include an \"overall_score\" (0-10) and a brief \"summary\" of your evaluation.\n",
    "    \"\"\"\n",
    "    \n",
    "    # User prompt with the query and response to be evaluated\n",
    "    user_prompt = f\"\"\"\n",
    "    Query: {query}\n",
    "    \n",
    "    Response to evaluate:\n",
    "    {response}\n",
    "    \"\"\"\n",
    "    \n",
    "    # Include reference answer in the prompt if provided\n",
    "    if reference_answer:\n",
    "        user_prompt += f\"\"\"\n",
    "    Reference answer (for comparison):\n",
    "    {reference_answer}\n",
    "    \"\"\"\n",
    "    \n",
    "    try:\n",
    "        # Request evaluation from the GPT-4 model\n",
    "        evaluation_response = client.chat.completions.create(\n",
    "            model=\"gpt-4\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_prompt}\n",
    "            ],\n",
    "            response_format={\"type\": \"json_object\"},\n",
    "            temperature=0\n",
    "        )\n",
    "        \n",
    "        # Parse the evaluation response\n",
    "        evaluation = json.loads(evaluation_response.choices[0].message.content)\n",
    "        return evaluation\n",
    "    except Exception as e:\n",
    "        # Handle any errors during the evaluation process\n",
    "        print(f\"Error evaluating response: {e}\")\n",
    "        return {\n",
    "            \"error\": str(e),\n",
    "            \"overall_score\": 0,\n",
    "            \"summary\": \"Evaluation failed due to an error.\"\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_crag_vs_standard_rag(query, vector_store, reference_answer=None):\n",
    "    \"\"\"\n",
    "    Compare CRAG against standard RAG for a query.\n",
    "    \n",
    "    Args:\n",
    "        query (str): User query\n",
    "        vector_store (SimpleVectorStore): Vector store with document chunks\n",
    "        reference_answer (str, optional): Reference answer for comparison\n",
    "        \n",
    "    Returns:\n",
    "        Dict: Comparison results\n",
    "    \"\"\"\n",
    "    # Run CRAG process\n",
    "    print(\"\\n=== Running CRAG ===\")\n",
    "    crag_result = crag_process(query, vector_store)\n",
    "    crag_response = crag_result[\"response\"]\n",
    "    \n",
    "    # Run standard RAG (directly retrieve and respond)\n",
    "    print(\"\\n=== Running standard RAG ===\")\n",
    "    query_embedding = create_embeddings(query)\n",
    "    retrieved_docs = vector_store.similarity_search(query_embedding, k=3)\n",
    "    combined_text = \"\\n\\n\".join([doc[\"text\"] for doc in retrieved_docs])\n",
    "    standard_sources = [{\"title\": \"Document\", \"url\": \"\"}]\n",
    "    standard_response = generate_response(query, combined_text, standard_sources)\n",
    "    \n",
    "    # Evaluate both approaches\n",
    "    print(\"\\n=== Evaluating CRAG response ===\")\n",
    "    crag_eval = evaluate_crag_response(query, crag_response, reference_answer)\n",
    "    \n",
    "    print(\"\\n=== Evaluating standard RAG response ===\")\n",
    "    standard_eval = evaluate_crag_response(query, standard_response, reference_answer)\n",
    "    \n",
    "    # Compare approaches\n",
    "    print(\"\\n=== Comparing approaches ===\")\n",
    "    comparison = compare_responses(query, crag_response, standard_response, reference_answer)\n",
    "    \n",
    "    return {\n",
    "        \"query\": query,\n",
    "        \"crag_response\": crag_response,\n",
    "        \"standard_response\": standard_response,\n",
    "        \"reference_answer\": reference_answer,\n",
    "        \"crag_evaluation\": crag_eval,\n",
    "        \"standard_evaluation\": standard_eval,\n",
    "        \"comparison\": comparison\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_responses(query, crag_response, standard_response, reference_answer=None):\n",
    "    \"\"\"\n",
    "    Compare CRAG and standard RAG responses.\n",
    "    \n",
    "    Args:\n",
    "        query (str): User query\n",
    "        crag_response (str): CRAG response\n",
    "        standard_response (str): Standard RAG response\n",
    "        reference_answer (str, optional): Reference answer\n",
    "        \n",
    "    Returns:\n",
    "        str: Comparison analysis\n",
    "    \"\"\"\n",
    "    # System prompt for comparing the two approaches\n",
    "    system_prompt = \"\"\"\n",
    "    You are an expert evaluator comparing two response generation approaches:\n",
    "    \n",
    "    1. CRAG (Corrective RAG): A system that evaluates document relevance and dynamically switches to web search when needed.\n",
    "    2. Standard RAG: A system that directly retrieves documents based on embedding similarity and uses them for response generation.\n",
    "    \n",
    "    Compare the responses from these two systems based on:\n",
    "    - Accuracy and factual correctness\n",
    "    - Relevance to the query\n",
    "    - Completeness of the answer\n",
    "    - Clarity and organization\n",
    "    - Source attribution quality\n",
    "    \n",
    "    Explain which approach performed better for this specific query and why.\n",
    "    \"\"\"\n",
    "    \n",
    "    # User prompt with the query and responses to be compared\n",
    "    user_prompt = f\"\"\"\n",
    "    Query: {query}\n",
    "    \n",
    "    CRAG Response:\n",
    "    {crag_response}\n",
    "    \n",
    "    Standard RAG Response:\n",
    "    {standard_response}\n",
    "    \"\"\"\n",
    "    \n",
    "    # Include reference answer in the prompt if provided\n",
    "    if reference_answer:\n",
    "        user_prompt += f\"\"\"\n",
    "    Reference Answer:\n",
    "    {reference_answer}\n",
    "    \"\"\"\n",
    "    \n",
    "    try:\n",
    "        # Request comparison from the GPT-4 model\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"gpt-4\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_prompt}\n",
    "            ],\n",
    "            temperature=0\n",
    "        )\n",
    "        \n",
    "        # Return the comparison analysis\n",
    "        return response.choices[0].message.content.strip()\n",
    "    except Exception as e:\n",
    "        # Handle any errors during the comparison process\n",
    "        print(f\"Error comparing responses: {e}\")\n",
    "        return f\"Error comparing responses: {str(e)}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Complete Evaluation Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_crag_evaluation(pdf_path, test_queries, reference_answers=None):\n",
    "    \"\"\"\n",
    "    Run a complete evaluation of CRAG with multiple test queries.\n",
    "    \n",
    "    Args:\n",
    "        pdf_path (str): Path to the PDF document\n",
    "        test_queries (List[str]): List of test queries\n",
    "        reference_answers (List[str], optional): Reference answers for queries\n",
    "        \n",
    "    Returns:\n",
    "        Dict: Complete evaluation results\n",
    "    \"\"\"\n",
    "    # Process document and create vector store\n",
    "    vector_store = process_document(pdf_path)\n",
    "    \n",
    "    results = []\n",
    "    \n",
    "    for i, query in enumerate(test_queries):\n",
    "        print(f\"\\n\\n===== Evaluating Query {i+1}/{len(test_queries)} =====\")\n",
    "        print(f\"Query: {query}\")\n",
    "        \n",
    "        # Get reference answer if available\n",
    "        reference = None\n",
    "        if reference_answers and i < len(reference_answers):\n",
    "            reference = reference_answers[i]\n",
    "        \n",
    "        # Run comparison between CRAG and standard RAG\n",
    "        result = compare_crag_vs_standard_rag(query, vector_store, reference)\n",
    "        results.append(result)\n",
    "        \n",
    "        # Display comparison results\n",
    "        print(\"\\n=== Comparison ===\")\n",
    "        print(result[\"comparison\"])\n",
    "    \n",
    "    # Generate overall analysis from individual results\n",
    "    overall_analysis = generate_overall_analysis(results)\n",
    "    \n",
    "    return {\n",
    "        \"results\": results,\n",
    "        \"overall_analysis\": overall_analysis\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_overall_analysis(results):\n",
    "    \"\"\"\n",
    "    Generate an overall analysis of evaluation results.\n",
    "    \n",
    "    Args:\n",
    "        results (List[Dict]): Results from individual query evaluations\n",
    "        \n",
    "    Returns:\n",
    "        str: Overall analysis\n",
    "    \"\"\"\n",
    "    # System prompt for the analysis\n",
    "    system_prompt = \"\"\"\n",
    "    You are an expert at evaluating information retrieval and response generation systems.\n",
    "    Based on multiple test queries, provide an overall analysis comparing CRAG (Corrective RAG) \n",
    "    with standard RAG.\n",
    "    \n",
    "    Focus on:\n",
    "    1. When CRAG performs better and why\n",
    "    2. When standard RAG performs better and why\n",
    "    3. The overall strengths and weaknesses of each approach\n",
    "    4. Recommendations for when to use each approach\n",
    "    \"\"\"\n",
    "    \n",
    "    # Create summary of evaluations\n",
    "    evaluations_summary = \"\"\n",
    "    for i, result in enumerate(results):\n",
    "        evaluations_summary += f\"Query {i+1}: {result['query']}\\n\"\n",
    "        if 'crag_evaluation' in result and 'overall_score' in result['crag_evaluation']:\n",
    "            crag_score = result['crag_evaluation'].get('overall_score', 'N/A')\n",
    "            evaluations_summary += f\"CRAG score: {crag_score}\\n\"\n",
    "        if 'standard_evaluation' in result and 'overall_score' in result['standard_evaluation']:\n",
    "            std_score = result['standard_evaluation'].get('overall_score', 'N/A')\n",
    "            evaluations_summary += f\"Standard RAG score: {std_score}\\n\"\n",
    "        evaluations_summary += f\"Comparison summary: {result['comparison'][:200]}...\\n\\n\"\n",
    "    \n",
    "    # User prompt for the analysis\n",
    "    user_prompt = f\"\"\"\n",
    "    Based on the following evaluations comparing CRAG vs standard RAG across {len(results)} queries, \n",
    "    provide an overall analysis of these two approaches:\n",
    "    \n",
    "    {evaluations_summary}\n",
    "    \n",
    "    Please provide a comprehensive analysis of the relative strengths and weaknesses of CRAG \n",
    "    compared to standard RAG, focusing on when and why one approach outperforms the other.\n",
    "    \"\"\"\n",
    "    \n",
    "    try:\n",
    "        # Generate the overall analysis using GPT-4\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"gpt-4\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_prompt}\n",
    "            ],\n",
    "            temperature=0\n",
    "        )\n",
    "        \n",
    "        return response.choices[0].message.content.strip()\n",
    "    except Exception as e:\n",
    "        print(f\"Error generating overall analysis: {e}\")\n",
    "        return f\"Error generating overall analysis: {str(e)}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation of CRAG with Test Queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the AI information PDF document\n",
    "pdf_path = \"data/AI_Information.pdf\"\n",
    "\n",
    "# Run comprehensive evaluation with multiple AI-related queries\n",
    "test_queries = [\n",
    "    \"How does machine learning differ from traditional programming?\",\n",
    "]\n",
    "\n",
    "# Optional reference answers for better quality evaluation\n",
    "reference_answers = [\n",
    "    \"Machine learning differs from traditional programming by having computers learn patterns from data rather than following explicit instructions. In traditional programming, developers write specific rules for the computer to follow, while in machine learning\",\n",
    "]\n",
    "\n",
    "# Run the full evaluation comparing CRAG vs standard RAG\n",
    "evaluation_results = run_crag_evaluation(pdf_path, test_queries, reference_answers)\n",
    "print(\"\\n=== Overall Analysis of CRAG vs Standard RAG ===\")\n",
    "print(evaluation_results[\"overall_analysis\"])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv-new-specific-rag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
